
Conversation

@yaoyaoding (Member) commented Nov 4, 2025

This PR adds a matmul implementation in tilus for the Blackwell architecture, using a warp-specialization optimization. This version of matmul achieves over 85% of cuBLAS performance (the torch rows below), roughly 86-93% across the benchmarked shapes.

| m | n | k | kernel | latency (ms) | TFLOPS |
|------:|------:|------:|--------|-------------:|------------:|
| 4096 | 4096 | 4096 | torch | 0.160752 | 854.975085 |
| 4096 | 4096 | 4096 | tilus | 0.179200 | 766.958441 |
| 4096 | 4096 | 14336 | torch | 0.460816 | 1043.879426 |
| 4096 | 4096 | 14336 | tilus | 0.494528 | 972.718110 |
| 8192 | 8192 | 8192 | torch | 0.909792 | 1208.530764 |
| 8192 | 8192 | 8192 | tilus | 1.022016 | 1075.826186 |
| 10240 | 10240 | 10240 | torch | 1.710224 | 1255.673881 |
| 10240 | 10240 | 10240 | tilus | 1.996880 | 1075.419481 |
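
For completeness: the kernel below is shown without its import preamble. A minimal preamble might look like the following; the exact module paths for the dtypes and for cdiv are my assumptions and may differ from the actual tilus package layout.

import tilus
from tilus import float16, float32, int32, uint32  # dtype aliases (assumed location)
from tilus.utils import cdiv                        # ceiling-division helper (assumed location)
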
@tilus.autotune("block_m, block_n", [[128, 64], [128, 128], [128, 256]])
@tilus.autotune("block_k", [16, 32, 64])
@tilus.autotune("stages", [2, 3, 4])
class BlackwellMatmul(tilus.Script):
    def __init__(self, block_m: int, block_n: int, block_k: int, stages: int):
        super().__init__()
        self.block_m = block_m
        self.block_n = block_n
        self.block_k = block_k
        self.stages = stages

    def __call__(
        self,
        m_size: int32,
        n_size: int,
        k_size: int,
        a_ptr: ~float16,
        b_ptr: ~float16,
        c_ptr: ~float16,
    ):
        self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)]
        self.attrs.warps = 4
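        # 4 warps (128 threads) per block: thread group 0 (warp 0) issues the TMA
        # copies, thread group 1 (warp 1) issues the tcgen05 MMAs, and all warps
        # cooperate in the epilogue store at the end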

        offset_m: int32 = self.block_m * self.blockIdx.x
        offset_n: int32 = self.block_n * self.blockIdx.y

        g_a = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size])
        g_b = self.global_view(b_ptr, dtype=float16, shape=[n_size, k_size])
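        # note: b is viewed as [n_size, k_size] (k-major), which is why the mma
        # below multiplies s_a by s_b.transpose()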
        s_a = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_m, self.block_k]
        )
        s_b = self.shared_tensor(
            dtype=float16, shape=[self.stages, self.block_n, self.block_k]
        )

        # allocate a tensor in tensor memory (tmem)
        t_acc = self.tcgen05.alloc(
            dtype=float32, shape=[self.block_m, self.block_n], init=0.0
        )

        # allocate barriers and the initial phases
        consumer_barriers = self.mbarrier.alloc(
            count=[1 for _ in range(self.stages)]
        )  # whether the data is ready for consumption
        producer_barriers = self.mbarrier.alloc(
            count=[1 for _ in range(self.stages)]
        )  # whether the data is ready to be filled

        with self.thread_group(group_index=0, group_size=32):
            # tma warp
            stage: int32 = 0
            producer_phases = self.register_tensor(
                dtype=uint32, shape=[self.stages], init=1
            )  # all stages are ready to be filled at the beginning
            for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
                self.mbarrier.wait(
                    producer_barriers[stage], phase=producer_phases[stage]
                )  # wait until the stage is ready to be filled
                producer_phases[stage] ^= 1
                with self.single_thread():
                    self.tma.global_to_shared(
                        src=g_a,
                        dst=s_a[stage],
                        offsets=[offset_m, offset_k],
                        mbarrier=consumer_barriers[stage],
                    )
                    self.tma.global_to_shared(
                        src=g_b,
                        dst=s_b[stage],
                        offsets=[offset_n, offset_k],
                        mbarrier=consumer_barriers[stage],
                    )
                    self.mbarrier.arrive(consumer_barriers[stage])
                stage = (stage + 1) % self.stages

            # remaining mma stages to wait for completion
            for _ in self.range(min(self.stages, cdiv(k_size, self.block_k))):
                self.mbarrier.wait(
                    producer_barriers[stage], phase=producer_phases[stage]
                )  # wait until the stage is ready to be filled
                producer_phases[stage] ^= 1
                stage = (stage + 1) % self.stages

        with self.thread_group(group_index=1, group_size=32):
            # mma warp
            consumer_phases = self.register_tensor(
                dtype=uint32, shape=[self.stages], init=0
            )  # all stages are not ready for consumption at the beginning
            stage: int32 = 0
            for offset_k in self.range(0, k_size, self.block_k, unroll=self.stages):
                self.mbarrier.wait(
                    consumer_barriers[stage], phase=consumer_phases[stage]
                )  # wait until the stage is ready for consumption
                consumer_phases[stage] ^= 1
                with self.single_thread():
                    self.tcgen05.mma(s_a[stage], s_b[stage].transpose(), t_acc)
                    self.tcgen05.commit(mbarrier=producer_barriers[stage])
                stage = (stage + 1) % self.stages

        self.sync()

        # load the result from tensor memory to register
        r_acc = self.tcgen05.load(
            t_acc, offsets=[0, 0], shape=[self.block_m, self.block_n]
        )

        g_c = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size])
        self.store_global(g_c, r_acc.to(float16), offsets=[offset_m, offset_n])

        # all allocated tensor memory must be deallocated
        self.sync()
        self.tcgen05.dealloc(t_acc)
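
The two thread groups above form a multi-stage producer/consumer pipeline: the TMA warp fills shared-memory stages while the MMA warp drains them, coordinated by two arrays of phase-bit mbarriers ("data ready" and "slot free") and per-stage expected-phase registers that are XOR-toggled after each wait. The plain-Python sketch below mimics that handshake on the CPU with threads. PhaseBarrier, the buffer contents, and the stage/tile counts are illustrative stand-ins rather than tilus or CUDA APIs, and real mbarrier parity semantics differ in detail.

# CPU-side analogy of the per-stage phase-bit handshake used above
import threading

class PhaseBarrier:
    """Toggle-phase barrier: arrive() flips the phase bit; wait(phase) blocks
    while the barrier is still in the given phase (i.e. until it flips)."""

    def __init__(self):
        self.phase = 0
        self._cv = threading.Condition()

    def arrive(self):
        with self._cv:
            self.phase ^= 1
            self._cv.notify_all()

    def wait(self, phase):
        with self._cv:
            self._cv.wait_for(lambda: self.phase != phase)

STAGES, TILES = 3, 8
buffers = [None] * STAGES
ready = [PhaseBarrier() for _ in range(STAGES)]   # plays the role of consumer_barriers
empty = [PhaseBarrier() for _ in range(STAGES)]   # plays the role of producer_barriers

def producer():                                   # the "tma warp"
    phases = [1] * STAGES                         # every stage starts out free
    for k in range(TILES):
        s = k % STAGES
        empty[s].wait(phases[s]); phases[s] ^= 1  # wait until the slot is free
        buffers[s] = f"tile {k}"                  # stands in for tma.global_to_shared
        ready[s].arrive()                         # signal: data is ready

def consumer():                                   # the "mma warp"
    phases = [0] * STAGES                         # nothing is ready at the beginning
    consumed = []
    for k in range(TILES):
        s = k % STAGES
        ready[s].wait(phases[s]); phases[s] ^= 1  # wait until data is ready
        consumed.append(buffers[s])               # stands in for tcgen05.mma
        empty[s].arrive()                         # stands in for tcgen05.commit
    print(consumed)

threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
for t in threads:
    t.start()
for t in threads:
    t.join()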

Minor changes:

  1. Enhance the scripts/sign-commits.py utility to rebase only from the first unsigned commit instead of from the main branch (a sketch follows this list).
  2. Add BarrierAllocContext and SyncContext in codegen to unify mbarrier allocation and mbarrier-based sub-group synchronization.
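
As a rough illustration of item 1, the sketch below shows one way to locate the first commit without a Signed-off-by trailer and rebase from there. The function names and git invocations are hypothetical and are not the actual contents of scripts/sign-commits.py.

# Hypothetical sketch; the real scripts/sign-commits.py may differ.
import subprocess

def git(*args: str) -> str:
    return subprocess.run(["git", *args], check=True, capture_output=True, text=True).stdout

def first_unsigned_commit(base: str = "main") -> str | None:
    """Oldest commit on the current branch (not on base) whose message lacks a Signed-off-by trailer."""
    for sha in git("rev-list", "--reverse", f"{base}..HEAD").split():
        if "Signed-off-by:" not in git("show", "-s", "--format=%B", sha):
            return sha
    return None

if __name__ == "__main__":
    sha = first_unsigned_commit()
    if sha is None:
        print("all commits already carry a sign-off; nothing to do")
    else:
        # replay only the commits starting at the first unsigned one, amending
        # each with `git commit -s`, instead of rebasing the whole branch from main
        git("rebase", "--exec", "git commit --amend --no-edit -s", f"{sha}^")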

Signed-off-by: Yaoyao Ding <[email protected]>
@yaoyaoding merged commit 2b74996 into main on Nov 4, 2025 (8 checks passed).